# Copyright 2023 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

from __future__ import annotations

import os
import pathlib
import shutil
import sys
from dataclasses import dataclass, field
from typing import Any, Final

import pytest
from immutabledict import immutabledict

from crossbench import config
from crossbench.cli import exception_formatter

root_dir = config.root_dir
config_dir = config.config_dir


def crossbench_dir() -> pathlib.Path:
  return root_dir() / "crossbench"


def is_on_swarming():
  return "SWARMING_SERVER" in os.environ


@dataclass(frozen=True)
class TestEnv():
  # Avoid getting PytestCollectionWarning as the class name starts with Test.
  __test__ = False

  output_dir: pathlib.Path
  test_name: str
  is_cq: bool = field(init=False)
  cq_flags: tuple[str, ...] = field(init=False)
  archive_dir: pathlib.Path = field(init=False)
  results_dir: pathlib.Path = field(init=False)
  root_dir: pathlib.Path = field(init=False)

  def __post_init__(self):
    output_path = pathlib.Path(self.output_dir)
    self._set("is_cq", output_path.parts[-1].startswith("cq_archive_"))
    self._set("cq_flags", ("--no-symlinks",) if self.is_cq else ())
    run_seq = 0
    while True:
      self._set("output_dir", output_path / self.test_name / str(run_seq))
      if not self.output_dir.exists():
        break
      run_seq += 1
    self.output_dir.mkdir(parents=True)
    self._set("archive_dir", self.output_dir / "browser_archive")
    assert not self.archive_dir.exists()
    self._set("results_dir", self.output_dir / "results")
    assert not self.results_dir.exists()
    self._set("root_dir", root_dir())

  def _set(self, attr: str, value: Any):
    object.__setattr__(self, attr, value)

  def remove_non_result(self):
    for output_path in self.output_dir.iterdir():
      if output_path.is_dir() and output_path != self.results_dir:
        shutil.rmtree(output_path)

  def assert_empty_output_dir(self):
    assert not tuple(self.output_dir.glob("**/*"))


DEFAULT_PYTEST_FLAGS: Final[immutabledict[str, str | None]] = immutabledict({
    "--verbose": None,
    "--log-file-level": "DEBUG",
    "--durations": 5,
    "--no-fold-skipped": None,
    "-r": "s",
})


def to_flags(flag_dict):
  for k, v in flag_dict.items():
    if v:
      yield f"{k}={v}"
    else:
      yield k


def run_pytest(path: str | pathlib.Path, *args):
  sys.excepthook = exception_formatter.excepthook
  extra_args = [*args, *sys.argv[1:]]
  # Run tests single-threaded by default when running the test file directly.
  if "-n" not in extra_args:
    extra_args.extend(["-n", "1"])
  if "-r" not in extra_args:
    extra_args.extend(["-r", "s"])
  sys.exit(pytest.main([str(path), *extra_args]))
